{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read data to table\n",
    "# Per line is a dict\n",
    "import json\n",
    "import re\n",
    "import numpy as np\n",
    "p = re.compile('(?<!\\\\\\\\)\\'')\n",
    "ret = {}\n",
    "def read_searching_log(output_data):\n",
    "    datas = []\n",
    "    with open(output_data, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "        for line in lines:\n",
    "            # Load data as dict\n",
    "            data = p.sub('\\\"', line)\n",
    "            param = json.loads(data)\n",
    "            for k in param:\n",
    "                if isinstance(param[k], list):\n",
    "                    param[k] = param[k][0]\n",
    "            # Add to datas\n",
    "            datas.append(param)\n",
    "    return datas\n",
    "def print_res_to_table(res, sort_index=\"final_cost\", number=10):\n",
    "    import tabulate\n",
    "    output_table = []\n",
    "    #Add title\n",
    "    if len(res) == 0:\n",
    "        return []\n",
    "    sort_values = []\n",
    "    for row in res:\n",
    "        output_table.append([])\n",
    "        for k in row:\n",
    "            v = row[k]\n",
    "            if isinstance(v, np.ndarray) or isinstance(v, list):\n",
    "                output_table[-1].append(v[0])\n",
    "            else:\n",
    "                if k == \"ATE_T\":\n",
    "                    output_table[-1].append(f\"{v:.3f}m\")\n",
    "                elif k == \"ATE_rot\":\n",
    "                    output_table[-1].append(f\"{v*57.3:.3f}\")\n",
    "                elif k == \"final_cost\":\n",
    "                    output_table[-1].append(f\"{v:.3f}\")\n",
    "                else:\n",
    "                    output_table[-1].append(v)\n",
    "            if k == sort_index:\n",
    "                sort_values.append(v)\n",
    "    output_table = np.array(output_table)\n",
    "    sort_values = np.array(sort_values)\n",
    "    output_table = output_table[sort_values.argsort()]\n",
    "    row0 = []\n",
    "    for key in res[0]:\n",
    "        row0.append(key)\n",
    "    output_table = output_table.tolist()\n",
    "    output_table.insert(0, row0)\n",
    "    output_table = output_table[:number]\n",
    "    return tabulate.tabulate(output_table, tablefmt='html')\n",
    "\n",
    "#Show the importance of the parameters using scatter plot\n",
    "import matplotlib.pyplot as plt\n",
    "def plot_param_vs_cost(ax, datas, param_name, cost_name=\"final_cost\", max_cost=1e6):\n",
    "    x = []\n",
    "    y = []\n",
    "    # Filter out the data with large cost\n",
    "    datas = [data for data in datas if data[cost_name] < max_cost]\n",
    "    for data in datas:\n",
    "        x.append(data[param_name])\n",
    "        y.append(data[cost_name])\n",
    "    ax[0].hist2d(x, y, bins=1000)\n",
    "    # Next we plot the param vs its min cost\n",
    "    min_cost = {}\n",
    "    for data in datas:\n",
    "        if data[param_name] not in min_cost:\n",
    "            min_cost[data[param_name]] = data[cost_name]\n",
    "        else:\n",
    "            min_cost[data[param_name]] = min(min_cost[data[param_name]], data[cost_name])\n",
    "    x = []\n",
    "    y = []\n",
    "    for k in min_cost:\n",
    "        x.append(k)\n",
    "        y.append(min_cost[k])\n",
    "    # Sorting by x\n",
    "    x = np.array(x)\n",
    "    y = np.array(y)\n",
    "    order = x.argsort()\n",
    "    x, y = x[order], y[order]\n",
    "    ax[1].plot(x, y)\n",
    "    for ax_ in ax:\n",
    "        ax_.set_yscale('log')\n",
    "        ax_.set_xscale('log')\n",
    "        ax_.set_xlabel(param_name)\n",
    "        ax_.set_ylabel(cost_name)\n",
    "        ax_.grid(which='both')\n",
    "    ax[1].set_ylabel(cost_name + \" (min)\")\n",
    "plt.rc('figure', figsize=(20, 5))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_data = \"/home/xuhao/data/d2slam/pgo/tum_corr_5/searching-output/search_result.txt\"\n",
    "params = [\"rho_frame_T\", \"rho_frame_theta\", \"eta_k\", \"rho_rot_mat\"]\n",
    "\n",
    "datas = read_searching_log(output_data)\n",
    "display(print_res_to_table(datas, number=5))\n",
    "fig, ax = plt.subplots(2, len(params))\n",
    "for i in range(len(params)):\n",
    "    plot_param_vs_cost(ax[:,i], datas, params[i], max_cost=1000)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "08ce52785f0fedc81003ce387e097a83d6cc9494681cd746006386992005bb71"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
